import numpy as np
import matplotlib.pyplot as plt


def plot_nine_images(images, labels, categories=None):
    """Show the first nine images in a 3x3 grid, each titled with its label.

    Args:
        images: Sliceable, indexable collection holding at least nine
            images. Each image is either an object exposing ``.numpy()``
            (for example a ``torch.Tensor``) or anything ``np.asarray``
            accepts. Every image is transposed with axes ``(1, 2, 0)``
            before display, so a channel-first ``(C, H, W)`` layout is
            presumably expected -- confirm against the caller.
        labels: Sliceable, indexable collection holding at least nine
            labels. Elements exposing ``.item()`` are unwrapped to a Python
            scalar; other values are used as they are.
        categories: Optional lookup (list, dict, array, ...) indexed by the
            label value. When given and non-empty, ``categories[label]`` is
            used as the title instead of the raw label.

    Raises:
        ValueError: If fewer than nine images or fewer than nine labels
            are supplied.
    """
    if len(images) < 9:
        raise ValueError("图片数量必须大于等于9张")
    # Fail before any figure is created instead of hitting an IndexError
    # halfway through the grid.
    if len(labels) < 9:
        raise ValueError("标签数量必须大于等于9个")

    images = images[:9]
    labels = labels[:9]
    # Explicit None/length check: plain truthiness raises for NumPy arrays.
    use_categories = categories is not None and len(categories) > 0

    fig, axes = plt.subplots(3, 3, figsize=(8, 8))
    # axes.flat walks the grid row by row, matching (i // 3, i % 3).
    for ax, image, label in zip(axes.flat, images, labels):
        # .numpy() refuses tensors that track gradients or live on an
        # accelerator, so drop the graph and move to host memory first.
        if hasattr(image, "detach"):
            image = image.detach()
        if hasattr(image, "cpu"):
            image = image.cpu()
        pixels = image.numpy() if hasattr(image, "numpy") else np.asarray(image)
        ax.imshow(pixels.transpose(1, 2, 0))

        value = label.item() if hasattr(label, "item") else label
        title = categories[value] if use_categories else value
        ax.set_title(f'{title}')
        ax.axis('off')

    plt.tight_layout()
    plt.show()